import importlib.util
import math
from pathlib import Path

import numpy as np
import pandas as pd
from linearmodels.iv import IV2SLS


def load_shiftshare_module(root: Path):
    """Load ``analysis/19_iv_shiftshare_rollout_absorb.py`` as a module object.

    The script's file name starts with a digit, so it cannot be pulled in with
    a normal ``import`` statement; it is loaded from its path instead and
    executed under the module name ``"shiftshare"``. The module is not
    registered in ``sys.modules``.

    Args:
        root: Project root containing the ``analysis`` directory.

    Returns:
        The executed module object.

    Raises:
        FileNotFoundError: If the analysis script does not exist under *root*.
        ImportError: If no import spec/loader can be built for the script.
    """
    path = root / "analysis" / "19_iv_shiftshare_rollout_absorb.py"
    if not path.is_file():
        raise FileNotFoundError(f"Shift-share analysis script not found: {path}")

    spec = importlib.util.spec_from_file_location("shiftshare", path)
    # Explicit checks instead of `assert`: asserts vanish under `python -O`,
    # and the spec must be validated before it is handed to module_from_spec.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot build an import spec for {path}")

    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod


def zscore(series: pd.Series) -> pd.Series:
    """Standardize *series* to mean 0 and population standard deviation 1.

    Values are first coerced to numeric (unparsable entries become NaN and
    stay NaN in the result). When the population standard deviation is
    missing or not strictly positive, the coerced values multiplied by zero
    are returned instead of dividing by it.
    """
    values = pd.to_numeric(series, errors="coerce")
    spread = values.std(ddof=0)
    if pd.notna(spread) and spread > 0:
        return (values - values.mean()) / spread
    # Degenerate case: constant, empty, or all-missing input.
    return values * 0


def fit_iv_custom(mod, df: pd.DataFrame, y: str, controls: list[str], instr_bases: list[str]) -> IV2SLS:
    """Fit the 2SLS model for outcome *y* with a chosen set of instruments.

    The endogenous regressors are ``netusoft`` and its interaction with
    ``edu_high``; each instrument in *instr_bases* enters both directly and
    interacted with ``edu_high``. ``region`` and ``ct`` are handed to
    ``mod.absorb_two_way`` as the two fixed-effect dimensions.

    Args:
        mod: Module returned by ``load_shiftshare_module``; must provide
            ``absorb_two_way`` and ``select_instruments_full_rank``.
        df: Estimation data containing every referenced column.
        y: Outcome column name.
        controls: Exogenous control column names (``edu_high`` is appended).
        instr_bases: Base instrument column names.

    Returns:
        The value of ``IV2SLS(...).fit(cov_type="clustered", ...)``.
        NOTE(review): the signature is annotated ``IV2SLS``, but what is
        returned is the fitted result (callers read ``.first_stage`` and
        ``.nobs`` from it) — confirm and correct the annotation.

    Raises:
        ValueError: If no complete rows remain after dropping missing values.
    """
    needed = [y, "netusoft", "edu_high", "ct", "region", *controls, *instr_bases]
    sample = df[needed].dropna().copy()
    if sample.empty:
        raise ValueError(f"No rows after dropna for outcome={y}")

    # Interactions with the education dummy: endogenous regressor first, then instruments.
    sample["netusoft_x_edu"] = sample["netusoft"] * sample["edu_high"]
    instr_x = [f"{b}_x_edu" for b in instr_bases]
    for base_name, inter_name in zip(instr_bases, instr_x):
        sample[inter_name] = sample[base_name] * sample["edu_high"]

    model_cols = [y, "netusoft", "netusoft_x_edu", *instr_bases, *instr_x, *controls, "edu_high"]
    absorbed = mod.absorb_two_way(sample, model_cols, fe_a="region", fe_b="ct")

    exog = absorbed[controls + ["edu_high"]]
    endog = absorbed[["netusoft", "netusoft_x_edu"]]
    # Ask the helper for a usable instrument subset, with at least one column per endogenous regressor.
    instruments = mod.select_instruments_full_rank(
        exog=exog,
        instruments=absorbed[instr_bases + instr_x],
        min_cols=endog.shape[1],
    )

    model = IV2SLS(absorbed[y], exog=exog, endog=endog, instruments=instruments)
    # Clusters come from the pre-absorption sample; assumes absorb_two_way keeps
    # the row order/index of its input — TODO confirm.
    return model.fit(cov_type="clustered", clusters=sample["region"])


def write_latex_table(path: Path, df: pd.DataFrame) -> None:
    """Write the ten summary columns of *df* to *path* as a booktabs ``tabular``.

    Numeric cells are rendered with a fixed format per column and missing
    values become empty cells; the ``label`` column is written as-is (no
    LaTeX escaping is applied). Columns of *df* not listed below are ignored.

    Args:
        path: Destination ``.tex`` file (overwritten, UTF-8).
        df: Frame containing at least the columns named in the layout.
    """
    # (source column, printed header, format spec or None for verbatim text)
    layout = [
        ("label", "Outcome", None),
        ("nobs", "N", ".0f"),
        ("coef_lowEdu", "b(LowEdu)", ".4f"),
        ("se_lowEdu", "se(LowEdu)", ".4f"),
        ("coef_highEdu", "b(HighEdu)", ".4f"),
        ("coef_deltaHighMinusLow", "Delta(High-Low)", ".4f"),
        ("se_deltaHighMinusLow", "se(Delta)", ".4f"),
        ("fsF_netusoft", "F(net)", ".2f"),
        ("fsF_netusoft_x_edu", "F(net$\\times$edu)", ".2f"),
        ("AR_pvalue", "AR p", ".4g"),
    ]
    columns = [name for name, _, _ in layout]
    titles = [title for _, title, _ in layout]

    table = df[columns].copy()
    for name, _, spec in layout:
        if spec is None:
            continue
        # Bind `spec` as a default so each column keeps its own format.
        table[name] = table[name].map(lambda v, spec=spec: format(v, spec) if pd.notna(v) else "")

    lines = [
        "\\begin{tabular}{lrrrrrrrrr}",
        "\\toprule",
        " & ".join(titles) + " \\\\",
        "\\midrule",
    ]
    lines.extend(
        " & ".join(str(cell) for cell in record) + " \\\\"
        for record in table.itertuples(index=False, name=None)
    )
    lines += ["\\bottomrule", "\\end{tabular}"]
    path.write_text("\n".join(lines) + "\n", encoding="utf-8")


def permute_shiftshare_instruments(root: Path, seed: int, instrument_bases: list[str]) -> pd.DataFrame:
    """
    Randomization inference by permuting baseline exposure across regions within each country.

    The baseline exposure (``z_bb_access_base_2016_ct_centered``) is shuffled
    among the regions of the same country, then every requested shift-share
    instrument is rebuilt as ``permuted exposure * shock`` and z-scored.

    Args:
        root: Project root containing ``data_external/rollout_shiftshare_region_year.csv``.
        seed: Seed for ``numpy.random.default_rng``.
        instrument_bases: Instrument names (e.g. ``"z_ss_cov30"``). Names with
            no known shock column, or whose shock column is absent from the
            CSV, are skipped.

    Returns:
        Frame with ``region``, ``survey_year`` and one ``<base>_p`` column per
        rebuilt instrument.

    Raises:
        ValueError: If none of the requested instruments can be rebuilt.
        pandas.errors.MergeError: If a region carries more than one distinct
            baseline exposure (the permuted exposure would be ambiguous).
    """
    exposure_col = "z_bb_access_base_2016_ct_centered"
    ss = pd.read_csv(root / "data_external" / "rollout_shiftshare_region_year.csv")
    ss["region"] = ss["region"].astype(str).str.upper().str.strip()
    ss["cntry"] = ss["cntry"].astype(str).str.upper().str.strip()

    bases_to_shocks = {
        "z_ss_cov30": "shock_cov30",
        "z_ss_cov100": "shock_cov100",
        "z_ss_fttp": "shock_fttp",
        "z_ss_gbps1": "shock_gbps1",
        "z_ss_vhcn_fx": "shock_vhcn_fx",
        "z_ss_nga": "shock_nga",
        "z_ss_docsis31": "shock_docsis31",
        "z_ss_cov2": "shock_cov2",
    }
    usable = [
        (name, bases_to_shocks[name])
        for name in instrument_bases
        if name in bases_to_shocks and bases_to_shocks[name] in ss.columns
    ]
    # Same contract as the shock/timing permutations: returning only the key
    # columns would let the caller silently keep the observed instruments.
    if not usable:
        raise ValueError("No shock columns available for exposure-permutation randomization inference.")

    exposures = ss[["region", "cntry", exposure_col]].drop_duplicates().copy()
    rng = np.random.default_rng(seed)
    shuffled = []
    for _, g in exposures.groupby("cntry", sort=False):
        exp = g[exposure_col].to_numpy(copy=True)
        rng.shuffle(exp)
        tmp = g[["region", "cntry"]].copy()
        tmp["z_bb_access_perm"] = exp
        shuffled.append(tmp)
    shuffled = pd.concat(shuffled, ignore_index=True)

    # validate= guards against silent row duplication when a region appears
    # with several exposure values.
    ss = ss.merge(shuffled, on=["region", "cntry"], how="left", validate="many_to_one")

    out_cols = ["region", "survey_year"]
    for name, shock in usable:
        ss[f"ss_{name}_p"] = ss["z_bb_access_perm"] * ss[shock]
        ss[f"{name}_p"] = zscore(ss[f"ss_{name}_p"])
        out_cols.append(f"{name}_p")

    return ss[out_cols]


def permute_shock_instruments(root: Path, seed: int, instrument_bases: list[str]) -> pd.DataFrame:
    """
    Randomization inference by reassigning national rollout shocks across countries within a year.

    Within every survey year the set of country-level shocks is kept, but each
    country receives the shocks of a randomly drawn country from that same
    year. Instruments are then rebuilt as ``exposure * permuted shock`` and
    z-scored.

    Returns a frame with ``region``, ``survey_year`` and one ``<base>_sp``
    column per rebuilt instrument. Raises ``ValueError`` when none of the
    requested instruments has a shock column in the CSV.
    """
    panel = pd.read_csv(root / "data_external" / "rollout_shiftshare_region_year.csv")
    for key in ("region", "cntry"):
        panel[key] = panel[key].astype(str).str.upper().str.strip()
    panel["survey_year"] = pd.to_numeric(panel["survey_year"], errors="coerce").astype("Int64")

    bases_to_shocks = {
        "z_ss_cov30": "shock_cov30",
        "z_ss_cov100": "shock_cov100",
        "z_ss_fttp": "shock_fttp",
        "z_ss_gbps1": "shock_gbps1",
        "z_ss_vhcn_fx": "shock_vhcn_fx",
        "z_ss_nga": "shock_nga",
        "z_ss_docsis31": "shock_docsis31",
        "z_ss_cov2": "shock_cov2",
    }
    available = {
        bases_to_shocks[name]
        for name in instrument_bases
        if name in bases_to_shocks and bases_to_shocks[name] in panel.columns
    }
    shock_cols = sorted(available)
    if not shock_cols:
        raise ValueError("No shock columns available for shock-permutation randomization inference.")

    # One row per distinct country-year shock vector.
    country_year = panel[["cntry", "survey_year", *shock_cols]].drop_duplicates().copy()
    rng = np.random.default_rng(seed)
    frames = []
    for _, year_rows in country_year.groupby("survey_year", sort=True):
        order = np.arange(len(year_rows))
        rng.shuffle(order)
        donors = year_rows.iloc[order]
        # Country labels stay in place; the shock values come from the shuffled donors.
        relabelled = year_rows[["cntry", "survey_year"]].reset_index(drop=True)
        for col in shock_cols:
            relabelled[f"{col}_perm"] = donors[col].to_numpy(copy=True)
        frames.append(relabelled)
    permuted = pd.concat(frames, ignore_index=True)

    panel = panel.merge(permuted, on=["cntry", "survey_year"], how="left", validate="many_to_one")

    out_cols = ["region", "survey_year"]
    for name in instrument_bases:
        shock = bases_to_shocks.get(name)
        perm_col = f"{shock}_perm"
        if shock is not None and shock in panel.columns and perm_col in panel.columns:
            panel[f"ss_{name}_sp"] = panel["z_bb_access_base_2016_ct_centered"] * panel[perm_col]
            panel[f"{name}_sp"] = zscore(panel[f"ss_{name}_sp"])
            out_cols.append(f"{name}_sp")

    return panel[out_cols]


def permute_shock_timing_within_country_instruments(root: Path, seed: int, instrument_bases: list[str]) -> pd.DataFrame:
    """
    Randomization inference by permuting the timing of national rollout shocks within each country.
    This preserves each country's shock distribution but breaks the mapping from year -> shock level.

    Args:
        root: Project root containing ``data_external/rollout_shiftshare_region_year.csv``.
        seed: Seed for ``numpy.random.default_rng``.
        instrument_bases: Instrument names (e.g. ``"z_ss_cov30"``); names without
            a shock column in the CSV are skipped.

    Returns:
        Frame with ``region``, ``survey_year`` and one ``<base>_tp`` column per
        rebuilt instrument. Rows whose ``survey_year`` could not be parsed get
        missing values in the ``_tp`` columns.

    Raises:
        ValueError: If none of the requested instruments has a shock column.
    """
    ss = pd.read_csv(root / "data_external" / "rollout_shiftshare_region_year.csv")
    ss["region"] = ss["region"].astype(str).str.upper().str.strip()
    ss["cntry"] = ss["cntry"].astype(str).str.upper().str.strip()
    ss["survey_year"] = pd.to_numeric(ss["survey_year"], errors="coerce").astype("Int64")

    bases_to_shocks = {
        "z_ss_cov30": "shock_cov30",
        "z_ss_cov100": "shock_cov100",
        "z_ss_fttp": "shock_fttp",
        "z_ss_gbps1": "shock_gbps1",
        "z_ss_vhcn_fx": "shock_vhcn_fx",
        "z_ss_nga": "shock_nga",
        "z_ss_docsis31": "shock_docsis31",
        "z_ss_cov2": "shock_cov2",
    }
    shock_cols = [bases_to_shocks[b] for b in instrument_bases if b in bases_to_shocks and bases_to_shocks[b] in ss.columns]
    shock_cols = sorted(set(shock_cols))
    if not shock_cols:
        raise ValueError("No shock columns available for timing-permutation randomization inference.")

    ct = ss[["cntry", "survey_year", *shock_cols]].drop_duplicates()
    # Years coerced to <NA> above cannot be converted with int() below; drop
    # them here so one unparsable year does not abort the whole permutation.
    ct = ct.dropna(subset=["survey_year"]).copy()
    rng = np.random.default_rng(seed)

    permuted = []
    for _, g in ct.groupby("cntry", sort=False):
        years = g["survey_year"].to_numpy()
        perm_years = years.copy()
        rng.shuffle(perm_years)
        # year -> the year whose shock it receives under this permutation
        mapping = {int(y): int(py) for y, py in zip(years, perm_years, strict=True)}

        out = g[["cntry", "survey_year"]].copy()
        for col in shock_cols:
            lookup = g.set_index("survey_year")[col].to_dict()
            out[f"{col}_perm"] = out["survey_year"].map(lambda y: lookup.get(mapping[int(y)], np.nan))
        permuted.append(out)

    permuted = pd.concat(permuted, ignore_index=True)
    ss = ss.merge(permuted, on=["cntry", "survey_year"], how="left", validate="many_to_one")

    out_cols = ["region", "survey_year"]
    for base in instrument_bases:
        shock = bases_to_shocks.get(base)
        if shock is None or shock not in ss.columns or f"{shock}_perm" not in ss.columns:
            continue
        # zscore() already coerces its input to numeric, so no extra to_numeric here.
        ss[f"{base}_tp"] = zscore(ss["z_bb_access_base_2016_ct_centered"] * ss[f"{shock}_perm"])
        out_cols.append(f"{base}_tp")

    return ss[out_cols]


def _permuted_first_stage_f(mod, df, root, permuter, suffix, n_perm, seed_start, y, controls, bases) -> pd.DataFrame:
    """Re-fit the first stage under ``n_perm`` instrument permutations.

    For permutation ``k`` the callable *permuter* is invoked with seed
    ``seed_start + k``; its ``<base><suffix>`` columns replace the observed
    instruments before the model is re-estimated. Returns one row per
    permutation with the first-stage F statistics of both endogenous regressors.
    """
    import warnings

    records = []
    for k in range(n_perm):
        permuted = permuter(root, seed=seed_start + k, instrument_bases=bases)
        dfp = df.merge(permuted, on=["region", "survey_year"], how="left", validate="many_to_one")

        kept_observed = [b for b in bases if f"{b}{suffix}" not in dfp.columns]
        if kept_observed:
            # These bases keep their real instrument, which pulls the
            # permutation distribution towards the observed F statistic.
            warnings.warn(
                f"No permuted column ('{suffix}') for instruments {kept_observed}; observed values are used.",
                RuntimeWarning,
                stacklevel=2,
            )
        for b in bases:
            permuted_col = f"{b}{suffix}"
            if permuted_col in dfp.columns:
                dfp[b] = dfp[permuted_col]

        res = fit_iv_custom(mod, dfp, y=y, controls=controls, instr_bases=bases)
        diag = res.first_stage.diagnostics
        records.append(
            {
                "perm": k + 1,
                "F_netusoft": float(diag.loc["netusoft", "f.stat"]),
                "F_netusoft_x_edu": float(diag.loc["netusoft_x_edu", "f.stat"]),
            }
        )
    return pd.DataFrame(records)


def _permutation_pvalue(perm_stats: pd.Series, observed: float, n_perm: int) -> float:
    """Return the add-one permutation p-value ``(#{perm >= observed} + 1) / (n_perm + 1)``."""
    return (int((perm_stats >= observed).sum()) + 1) / (n_perm + 1)


def _first_stage_summary(kind: str, nobs: int, f_net: float, f_int: float, p_net: float, p_int: float, n_perm: int) -> pd.DataFrame:
    """Build the two-row summary (observed F, permutation p-values) in the table layout."""

    def _row(label: str, net: float, inter: float) -> dict:
        return {
            "label": label,
            "nobs": nobs,
            "coef_lowEdu": np.nan,
            "se_lowEdu": np.nan,
            "coef_highEdu": np.nan,
            "coef_deltaHighMinusLow": np.nan,
            "se_deltaHighMinusLow": np.nan,
            "fsF_netusoft": net,
            "fsF_netusoft_x_edu": inter,
            "AR_pvalue": np.nan,
        }

    # NOTE(review): the second row stores p-values in the F-statistic columns,
    # which write_latex_table prints with two decimals (e.g. 1/201 -> "0.00").
    return pd.DataFrame(
        [
            _row(f"randomization ({kind}): observed first-stage F", f_net, f_int),
            _row(f"randomization ({kind}): permuted F p-values (n={n_perm})", p_net, p_int),
        ]
    )


def main() -> None:
    """Run the instrument-sensitivity and randomization-inference stress tests.

    Writes CSV/Markdown/text outputs under ``<root>/outputs`` and LaTeX tables
    under ``<root>/paper_joc/tables``, where ``<root>`` is the parent of the
    directory containing this script.
    """
    root = Path(__file__).resolve().parents[1]
    out_dir = root / "outputs"
    out_dir.mkdir(parents=True, exist_ok=True)
    paper_tables = root / "paper_joc" / "tables"
    paper_tables.mkdir(parents=True, exist_ok=True)

    mod = load_shiftshare_module(root)
    df = mod.build_dataset(root)

    # Sensitivity: instrument components (keep outcome fixed)
    y = "y_part_index"
    controls = ["agea", "gndr", "hinctnta"]
    all_bases = [c for c in getattr(mod, "INSTRUMENT_BASE_COLS", []) if c in df.columns]
    variants = [
        ("All (enhanced)", all_bases),
        ("Legacy (cov30+cov100+fttp)", ["z_ss_cov30", "z_ss_cov100", "z_ss_fttp"]),
        ("cov30 only", ["z_ss_cov30"]),
        ("cov100 only", ["z_ss_cov100"]),
        ("fttp only", ["z_ss_fttp"]),
        ("NGA only", ["z_ss_nga"]),
        ("VHCN (fixed) only", ["z_ss_vhcn_fx"]),
        ("Gigabit coverage only", ["z_ss_gbps1"]),
        ("DOCSIS 3.1 only", ["z_ss_docsis31"]),
    ]
    rows = []
    for label, bases in variants:
        bases = [b for b in bases if b in df.columns]
        if not bases:
            continue
        res = fit_iv_custom(mod, df, y=y, controls=controls, instr_bases=bases)
        rows.append(mod.summarize_result(res, y=y, label=f"sensitivity: {label}"))
    sens = pd.DataFrame(rows)
    (out_dir / "iv_shiftshare_absorb_instrument_sensitivity.csv").write_text(sens.to_csv(index=False), encoding="utf-8")
    (out_dir / "iv_shiftshare_absorb_instrument_sensitivity.md").write_text(mod.to_markdown_table(sens) + "\n", encoding="utf-8")
    write_latex_table(paper_tables / "iv_shiftshare_instrument_sensitivity.tex", sens)

    # Randomization inference on first-stage strength
    n_perm_main = 200  # timing-permutation within country (default in supplement)
    n_perm_other = 50  # auxiliary randomizations (reference only)
    seed0 = 20251223
    # observed first-stage (use same estimation sample as y_part_index)
    obs_bases = all_bases if all_bases else ["z_ss_cov30", "z_ss_cov100", "z_ss_fttp"]
    obs = fit_iv_custom(mod, df, y=y, controls=controls, instr_bases=obs_bases)
    diag = obs.first_stage.diagnostics
    f_obs_net = float(diag.loc["netusoft", "f.stat"])
    f_obs_int = float(diag.loc["netusoft_x_edu", "f.stat"])
    nobs = int(obs.nobs)

    # --- Exposure permutation (kept for reference, but not the default supplement table) ---
    exp_perm = _permuted_first_stage_f(
        mod, df, root, permute_shiftshare_instruments, "_p", n_perm_other, seed0, y, controls, obs_bases
    )
    (out_dir / "iv_shiftshare_randomization_exposure_firststage.csv").write_text(exp_perm.to_csv(index=False), encoding="utf-8")
    p_net_exp = _permutation_pvalue(exp_perm["F_netusoft"], f_obs_net, n_perm_other)
    p_int_exp = _permutation_pvalue(exp_perm["F_netusoft_x_edu"], f_obs_int, n_perm_other)
    exp_summary = _first_stage_summary(
        "exposure-permutation", nobs, f_obs_net, f_obs_int, p_net_exp, p_int_exp, n_perm_other
    )
    write_latex_table(paper_tables / "iv_shiftshare_randomization_exposure_firststage_summary.tex", exp_summary)

    # --- Shock permutation (reference only) ---
    shock_perm = _permuted_first_stage_f(
        mod, df, root, permute_shock_instruments, "_sp", n_perm_other, seed0 + 1000, y, controls, obs_bases
    )
    (out_dir / "iv_shiftshare_randomization_shock_firststage.csv").write_text(shock_perm.to_csv(index=False), encoding="utf-8")
    p_net_sp = _permutation_pvalue(shock_perm["F_netusoft"], f_obs_net, n_perm_other)
    p_int_sp = _permutation_pvalue(shock_perm["F_netusoft_x_edu"], f_obs_int, n_perm_other)

    # --- Timing permutation within country (DEFAULT; canonical outputs used by supplement + figure) ---
    timing_perm = _permuted_first_stage_f(
        mod,
        df,
        root,
        permute_shock_timing_within_country_instruments,
        "_tp",
        n_perm_main,
        seed0 + 2000,
        y,
        controls,
        obs_bases,
    )
    (out_dir / "iv_shiftshare_randomization_firststage.csv").write_text(timing_perm.to_csv(index=False), encoding="utf-8")
    p_net_tp = _permutation_pvalue(timing_perm["F_netusoft"], f_obs_net, n_perm_main)
    p_int_tp = _permutation_pvalue(timing_perm["F_netusoft_x_edu"], f_obs_int, n_perm_main)
    summary = _first_stage_summary(
        "timing-permutation within country", nobs, f_obs_net, f_obs_int, p_net_tp, p_int_tp, n_perm_main
    )
    (out_dir / "iv_shiftshare_randomization_firststage_summary.csv").write_text(summary.to_csv(index=False), encoding="utf-8")
    write_latex_table(paper_tables / "iv_shiftshare_randomization_firststage_summary.tex", summary)
    (out_dir / "iv_shiftshare_randomization_firststage_summary.txt").write_text(
        "\n".join(
            [
                f"Observed first-stage F (netusoft): {f_obs_net:.3f}",
                f"Observed first-stage F (netusoft_x_edu): {f_obs_int:.3f}",
                f"Timing-permutation p-value (F_netusoft >= obs): {p_net_tp:.4f}",
                f"Timing-permutation p-value (F_netusoft_x_edu >= obs): {p_int_tp:.4f}",
                f"Shock-permutation p-value (F_netusoft >= obs): {p_net_sp:.4f}",
                f"Shock-permutation p-value (F_netusoft_x_edu >= obs): {p_int_sp:.4f}",
                f"Exposure-permutation p-value (F_netusoft >= obs): {p_net_exp:.4f}",
                f"Exposure-permutation p-value (F_netusoft_x_edu >= obs): {p_int_exp:.4f}",
            ]
        )
        + "\n",
        encoding="utf-8",
    )

    print("Wrote stress-test outputs to outputs/ and paper_joc/tables/")

# Run the stress tests only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
